package com.xmzs.common.chat.service.impl;

import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.xmzs.common.chat.config.LocalCache;
import com.xmzs.common.chat.entity.chat.ChatCompletion;
import com.xmzs.common.chat.entity.chat.Message;
import com.xmzs.common.chat.listener.OpenAISSEEventSourceListener;
import com.xmzs.common.chat.openai.OpenAiStreamClient;
import com.xmzs.common.chat.service.SseService;
import com.xmzs.common.chat.domain.request.ChatRequest;
import com.xmzs.common.core.domain.model.LoginUser;
import com.xmzs.common.core.exception.base.BaseException;
import com.xmzs.common.core.utils.StringUtils;
import com.xmzs.common.satoken.utils.LoginHelper;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * 描述：
 *
 * @author https:www.unfbx.com
 * @date 2023-04-08
 */
@Service
@Slf4j
public class SseServiceImpl implements SseService {

    private final OpenAiStreamClient openAiStreamClient;

    public SseServiceImpl(OpenAiStreamClient openAiStreamClient) {
        this.openAiStreamClient = openAiStreamClient;
    }


    @Override
    public ResponseBodyEmitter sseChat(ChatRequest chatRequest) {
        if (StrUtil.isBlank(chatRequest.getPrompt())) {
            throw new BaseException("参数异常，msg不能为空~");
        }
        // 上下文信息
        LinkedList<Message> messages = new LinkedList<>();
        // 判断是否需要携带上下文信息
        if(chatRequest.getUsingContext()){
            // 获取对话记录
            messages = LocalCache.getUserChatMessages(chatRequest.getConversationId(), chatRequest.getContentNumber());
        }
        // 设置系统角色
        Message systemMessage = Message.builder().content(chatRequest.getSystemMessage()).role(Message.Role.SYSTEM).build();
        messages.addFirst(systemMessage);
        // 添加本次消息记录
        Message message = Message.builder().content(chatRequest.getPrompt()).role(Message.Role.USER).build();
        messages.add(message);
        ResponseBodyEmitter sseEmitter = getResponseBodyEmitter(chatRequest);
        OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener(sseEmitter);
        ChatCompletion completion = ChatCompletion
                .builder()
                .messages(messages)
                .model(chatRequest.getModel())
                .user(chatRequest.getConversationId())
                .temperature(chatRequest.getTemperature())
                .topP(chatRequest.getTop_p())
                .stream(true)
                .build();
        openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
        LocalCache.MESSAGE.put(chatRequest.getConversationId(), messages);

        return sseEmitter;
    }


    /**
     * 创建sseEmitter
     * @param chatRequest
     * @return
     */
    private static ResponseBodyEmitter getResponseBodyEmitter(ChatRequest chatRequest) {

        ResponseBodyEmitter sseEmitter = new ResponseBodyEmitter(0L);

        sseEmitter.onCompletion(() -> {
            log.info("[{}]结束连接...................", chatRequest.getConversationId());
            LocalCache.CACHE.remove(chatRequest.getConversationId());
        });

        //超时回调
        sseEmitter.onTimeout(() -> {
            log.error("[{}]连接超时...................", chatRequest.getConversationId());
        });

        //异常回调
        sseEmitter.onError(
            throwable -> {
                log.error("[{}]连接失败...................", chatRequest.getConversationId());
            }
        );

        LocalCache.CACHE.put(chatRequest.getConversationId(), sseEmitter);

        log.info("[{}]创建sse连接成功！", chatRequest.getConversationId());

        return sseEmitter;
    }
}
